Text embeddings differ between Elasticsearch and the Python library

Deploying a Sentence Transformers model to Elasticsearch with Eland, as described in this blog, returns only the embedding corresponding to the first token. Encoding the same text with the Python library, however, returns the mean-pooled embedding of all tokens.

Embeddings for "paytm" from the Python library; the final embedding is the mean of these token embeddings, returned as a single array:

tensor([[-0.5932,  0.2360,  0.9183,  ..., -0.6631, -0.1304,  0.5950],
    [-0.6218, -0.5103,  0.9183,  ..., -0.3577,  0.1628,  0.5489],
    [-0.7060,  0.8069,  0.5988,  ..., -0.8209, -0.0161,  0.3651],
    [-0.5804, -0.1383,  1.1632,  ..., -0.4210, -0.2995,  0.3780]])
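For illustration, here is a minimal plain-Python sketch of the mean pooling that `SentenceTransformer.encode` applies across the per-token vectors (the 3-dimensional numbers below are made up for readability; the real embeddings are much wider):

```python
# Toy example: mean-pool a list of per-token embeddings into one vector.
# The vectors are made up for illustration, not real model output.
token_embeddings = [
    [-0.5932, 0.2360, 0.5950],
    [-0.6218, -0.5103, 0.5489],
    [-0.7060, 0.8069, 0.3651],
    [-0.5804, -0.1383, 0.3780],
]

num_tokens = len(token_embeddings)
dim = len(token_embeddings[0])

# Average each dimension across all tokens to get one pooled vector.
pooled = [
    sum(tok[d] for tok in token_embeddings) / num_tokens
    for d in range(dim)
]
print(pooled)
```

The CLS approach instead takes only `token_embeddings[0]`, which is why the two pipelines disagree.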

Embedding for "paytm" from Elasticsearch:

{
  "inference_results": [
    {
      "predicted_value": [
        0.22763113677501678,
        -0.14426757395267487,
        -0.5313942432403564,
        0.38780003786087036,
        0.23178794980049133,
        0.2520352303981781,
        -0.1310071498155594,
        0.1372310221195221,
        0.8867185115814209,
        0.28502148389816284,
        -0.34933239221572876,
        -0.048524804413318634,
        -0.030298544093966484,
        0.3023957908153534,
        0.16458909213542938,
        -0.22286038100719452,
        -0.4322550594806671,
        -0.17900292575359344,
        0.41147494316101074,
        0.5585047602653503,
        0.3104747235774994,
        -0.6178449392318726,
        -0.10300447046756744,
        -0.21040283143520355,
        -0.18054519593715668,
        0.37258920073509216,
        -0.06892939656972885,
        0.4433431029319763,
        -0.49753281474113464,
        -0.6661249995231628,
        0.4594205617904663,
        0.1055305227637291,
        -0.04052606597542763,
        0.6544049382209778,
        -0.00035589560866355896,
        0.5667051672935486,
        -0.14461682736873627,
        0.3525410294532776,
        -0.05935198441147804,
        -0.11387929320335388,
        0.04457880184054375,
        -0.19744209945201874,
        -0.2772529125213623,
        -0.21505369246006012,
        0.07984670996665955,
        -0.5984102487564087,
        0.648749053478241,
        0.03957577049732208,
        -0.2774036228656769,
        -0.43836915493011475,
        -0.13032186031341553,
        -0.04730977490544319,
        0.13484464585781097,
        0.38720041513442993,
        0.3711908161640167,
        -0.6393144726753235,
        0.5307539105415344,
        -0.4074456989765167,
        -0.17671692371368408,
        0.04556211456656456,
        -0.4378435015678406,
        -0.5093674063682556,
        -0.41535666584968567,
        -0.14222007989883423,
        0.32243072986602783,
        -0.035575903952121735,
        -0.8241013288497925,
        0.35598224401474,
        -0.1930820792913437,
        0.10115426033735275,
        -0.18606221675872803,
        -0.2894442677497864,
        0.014790929853916168,
        -0.46368759870529175,
        -0.1974625289440155,
        0.2996467649936676,
        0.13894493877887726,
        -0.13177761435508728,
        0.210512176156044,
        -0.07767431437969208,
        -0.1655813753604889,
        0.451304167509079,
        0.5005711913108826,
        -0.03761318698525429,
        0.13771039247512817,
        0.4428074359893799,
        0.3595246970653534,
        0.10929585248231888,
        0.39517447352409363,
        -0.12278398871421814,
        0.40418583154678345,
        0.2836850881576538,
        0.7709897756576538,
        0.35056984424591064,
        0.5010964870452881,
        -0.2106258124113083,
        0.3735590875148773,
        -0.2123080939054489,
        -0.23633287847042084,
        4.190859794616699,
        0.9373717904090881,
        0.26693618297576904,
        -0.4606470763683319,
        0.16023531556129456,
        0.11719824373722076,
        -0.1723480224609375,
        0.001891970168799162,
        0.32362309098243713,
        -0.2978127896785736,
        0.8658319711685181,
        0.12660570442676544,
        0.07594136148691177,
        0.1375199258327484,
        0.35534992814064026,
        -0.5539746284484863,
        0.15599387884140015,
        -0.13606874644756317,
        -0.47219544649124146,
        -0.11688024550676346,
        0.04422470182180405,
        -0.5628999471664429,
        0.5157653093338013,
        0.0003223186358809471,
        -0.7226400375366211,
        -0.230352520942688,
        -0.8993479013442993,
        0.15943549573421478,
        -0.14756599068641663,
        -0.12611696124076843,
        -0.05348636955022812,
        -0.06012176722288132,
        -0.3392755389213562,
        -0.670730710029602,
        -0.19131512939929962,
        0.540212094783783,
        0.4702330231666565,
        0.3174777030944824,
        0.10704193264245987,
        0.27343031764030457,
        -0.024485638365149498,
        0.29616883397102356,
        -0.2712949812412262,
        -0.40323683619499207,
        -0.1152970939874649,
        -0.6109105944633484,
        0.0476105697453022,
        0.3698404133319855,
        0.10193509608507156,
        0.20707038044929504,
        0.08452828228473663,
        -0.5184361934661865,
        -0.3315282464027405,
        -0.015834353864192963,
        -0.05709962919354439,
        -0.1773345172405243,
        -0.12852440774440765,
        0.4265013337135315,
        -0.7410590052604675,
        0.18239693343639374,
        0.2447529435157776,
        -0.5534960627555847,
        -0.25594016909599304,
        0.19369451701641083,
        -0.4599858224391937,
        0.39168205857276917,
        -0.3994428813457489,
        -0.03455989435315132,
        -0.7195948362350464,
        -0.05066046491265297,
        -0.649211585521698,
        -0.4032367467880249,
        0.3060555160045624,
        -0.508898913860321,
        0.3905222713947296,
        -0.2485397756099701,
        -0.11877850443124771,
        0.023775644600391388,
        -0.5504740476608276,
        -0.2715173363685608,
        0.03782191500067711,
        0.051315758377313614,
        0.35944855213165283,
        -0.14557619392871857,
        -0.9205848574638367,
        -0.15760433673858643,
        0.03252140060067177,
        0.04119274020195007,
        -0.07748888432979584,
        0.1008782759308815,
        -0.8629116415977478,
        -0.1624283492565155,
        -0.7425805926322937,
        -0.35126468539237976,
        0.10007870942354202,
        -0.33232635259628296,
        -0.051030684262514114,
        0.31561392545700073,
        0.10930152982473373,
        0.4459870159626007,
        0.15022096037864685,
        0.42052289843559265,
        0.49589547514915466,
        0.03847932070493698,
        -0.039301712065935135,
        0.13951952755451202,
        -0.0566532164812088,
        0.12938320636749268,
        0.7096655368804932,
        -0.1356537938117981,
        0.014559656381607056,
        -0.3189936578273773,
        0.0051023634150624275,
        -0.045201703906059265,
        -0.37529733777046204,
        0.047298870980739594,
        -0.29170912504196167,
        -0.16526596248149872,
        -0.33550798892974854,
        -0.25210806727409363,
        -0.20536242425441742,
        -0.13906940817832947,
        0.11860240250825882,
        -0.04340652376413345,
        -0.7233018279075623,
        -0.1331174224615097,
        -0.07978124916553497,
        -0.06738846749067307,
        0.6278455853462219,
        0.5404669642448425,
        -0.13116887211799622,
        0.12284678965806961,
        -0.46880680322647095,
        0.04024381563067436,
        0.5078399777412415,
        0.04804232716560364,
        -0.05504803732037544,
        -0.369282603263855,
        -0.2652692496776581,
        0.12957318127155304,
        0.6614379286766052,
        -0.11627178639173508,
        0.12523607909679413,
        0.5214363932609558,
        0.04885203391313553,
        0.0536777563393116,
        -0.31201857328414917,
        0.07515550404787064,
        0.6545552015304565,
        -0.31832942366600037,
        0.6508542895317078,
        0.0718972235918045,
        -0.618195652961731,
        0.5038871765136719,
        0.008811783976852894,
        0.4164590537548065,
        -0.15855854749679565,
        -0.5688043236732483,
        0.4058866798877716,
        -0.3939337134361267,
        -0.18349924683570862,
        -0.22742922604084015,
        0.11201909184455872,
        -0.27650734782218933,
        -0.10007217526435852,
        -0.07031919807195663,
        -0.04137275367975235,
        0.20231564342975616,
        0.7155093550682068,
        0.881754994392395,
        -0.7824999094009399,
        -0.00576990470290184,
        -0.1560748666524887,
        0.3927711546421051,
        -0.3355693221092224,
        0.21854330599308014,
        -0.12200477719306946,
        -0.389082133769989,
        -0.022180264815688133,
        0.13669267296791077,
        -0.3045346140861511,
        0.6712400317192078,
        0.08850947767496109,
        0.08425739407539368,
        0.036125488579273224,
        0.09147649258375168,
        -0.36862194538116455,
        0.28320005536079407,
        0.5554841756820679,
        -0.26776692271232605,
        -0.32163357734680176,
        0.7347353100776672,
        0.23886728286743164,
        0.1397193968296051,
        0.05582265183329582,
        -0.15992748737335205,
        0.1794859617948532,
        -0.16006292402744293,
        0.3651401400566101,
        0.2118891328573227,
        1.1367347240447998,
        0.29052412509918213,
        -0.6307929158210754,
        -0.24623067677021027,
        -0.006436658091843128,
        0.31232452392578125,
        0.10544918477535248,
        -0.006761192809790373,
        0.32247036695480347,
        -0.3944847583770752,
        0.7537521719932556,
        -0.10560796409845352,
        0.06467892229557037,
        -0.33669328689575195,
        0.2395411729812622,
        -0.20653904974460602,
        -0.0971890389919281,
        -0.03428728133440018,
        0.03297343850135803,
        -0.1080745980143547,
        -1.0874756574630737,
        -0.18504175543785095,
        -0.5369579792022705,
        0.5190311670303345,
        0.26966139674186707,
        -0.35968539118766785,
        0.060317639261484146,
        -0.17809531092643738,
        0.13922858238220215,
        0.356803834438324,
        0.08462706953287125,
        0.21717974543571472,
        -0.22614267468452454,
        0.47436395287513733,
        0.7518540024757385,
        0.09135100990533829,
        -0.007415523752570152,
        0.12765496969223022,
        0.2695424258708954,
        0.08344051241874695,
        -0.3245477080345154,
        0.1493050903081894,
        0.2552494704723358,
        0.3694363534450531,
        -0.2119334638118744,
        -0.5604385733604431,
        0.11661115288734436,
        -0.5152119994163513,
        0.5081479549407959,
        -0.017947545275092125,
        -0.25604650378227234,
        -0.45127663016319275,
        0.05397222563624382,
        -0.043777838349342346,
        -0.29821574687957764,
        -0.6421034336090088,
        0.0829707607626915,
        0.7568528652191162,
        -0.39516523480415344,
        -0.6443259119987488,
        0.1516796052455902,
        -0.6404282450675964,
        -0.5156242251396179,
        -0.4068266451358795,
        -0.4147300720214844,
        -0.028562845662236214,
        0.09917867183685303,
        -0.30183061957359314,
        0.23182949423789978,
        -0.1346864402294159,
        -0.4248684048652649,
        0.7310291528701782,
        -0.08360954374074936,
        0.045938946306705475,
        0.24128252267837524,
        -0.2900591790676117,
        -0.1171850860118866,
        -0.8898353576660156,
        -0.4099249541759491,
        0.5582566857337952,
        -0.3977157175540924,
        -0.12525013089179993,
        -0.13443447649478912,
        0.24326668679714203,
        0.08908563107252121
      ]
    }
  ]
}

I want Elasticsearch to return the same embeddings as the Python library. Can someone please help me understand this?

Hi @Roshan_Kumar1

Thanks for opening the issue. Which Sentence Transformers model are you using, and can you share the Python code you used to evaluate the model, please?

I'm assuming the blog you referred to is Text Embeddings and Vector Search and the model you are using is msmarco-MiniLM-L-12-v3.

Elasticsearch is returning the embedding for the CLS token, which is a representation of the entire sentence. I compared the outputs of the model in Python and Elasticsearch and they are the same. Here is the Python code I used:

from transformers import AutoTokenizer, AutoModel
import torch

# Load model from HuggingFace
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/msmarco-MiniLM-L-12-v3')
model = AutoModel.from_pretrained('sentence-transformers/msmarco-MiniLM-L-12-v3')

# tokenize
encoded_input = tokenizer("CLS is the first token in the output", padding=True, truncation=True, return_tensors='pt')

# evaluate
with torch.no_grad():
    model_output = model(**encoded_input)


# the CLS embedding is the first token of the last hidden state
print("CLS:")
print(model_output[0][0][0])

Hi @dkyle,

Thanks for your reply,

The Python code we are using is:

from sentence_transformers import SentenceTransformer
query = 'paytm'
model = SentenceTransformer('sentence-transformers/msmarco-MiniLM-L-12-v3')
query_embedding = model.encode(query)
print(query_embedding)

and this code returns a different embedding. I assume this is because the encode method of sentence_transformers applies a mean pooling layer and returns the mean embedding of all the tokens.

How can we add the mean pooling layer in Elasticsearch?

Python code embedding output:

[ 1.57020450e-01 -3.04969966e-01 -4.18578237e-01  1.95406750e-01
  2.02006668e-01  5.58146685e-02 -4.38398391e-01 -2.81178951e-01
  4.15557325e-01 -4.18006293e-02 -2.92023152e-01  7.58312568e-02
 -1.70991749e-01  4.99069467e-02  1.76464975e-01 -2.08336204e-01
 -1.18595846e-01 -5.65242410e-01  7.41798580e-01  4.85240102e-01
 -1.64378211e-01 -3.07181358e-01  1.01713732e-01  3.40564810e-02
 -4.94990259e-01  1.66547596e-01 -1.66534334e-01  1.60300225e-01
 -1.57258272e-01 -3.07429552e-01  6.33597225e-02  2.42984951e-01
 -1.99134294e-02  3.79665017e-01 -1.82920009e-01  3.93111229e-01
  6.19311519e-02  2.85843480e-02 -1.39212325e-01 -1.68606360e-03
  6.56508356e-02  4.86712195e-02 -1.47408381e-01 -1.80367991e-01
 -1.93030730e-01 -1.99212745e-01  4.15605605e-01 -3.52042876e-02
 -3.29515457e-01 -9.90670472e-02 -2.08950937e-01 -9.69196558e-02
  2.31286824e-01  1.99292421e-01  1.24570832e-01 -7.73928940e-01
  3.26086581e-01 -6.66038394e-02 -4.74421978e-01  2.32767582e-01
 -4.48115587e-01 -3.22077096e-01 -2.23321900e-01 -2.01462433e-01
  6.26845241e-01 -2.97835171e-01 -5.71593165e-01  3.02757442e-01
  5.24613634e-03 -1.31223679e-01 -3.09483498e-01 -2.47762471e-01
 -8.12271461e-02 -1.26609564e-01 -1.24954954e-01  1.99522004e-01
  5.32756299e-02  9.06654447e-02  2.14736640e-01 -2.31381148e-01
 -3.65558594e-01  1.29730389e-01  4.58975941e-01  2.17153445e-01
 -8.90894327e-03  5.44396579e-01  2.20549181e-01  8.79131854e-02
 -6.92170858e-02 -1.26884490e-01 -9.04388055e-02  1.23395510e-01
  2.72564471e-01  2.65780658e-01  2.99320012e-01  2.14289397e-01
  7.39321560e-02  1.63154185e-01 -1.35010242e-01  1.84587884e+00
  7.73784161e-01  1.76078543e-01 -2.75482059e-01 -1.29865363e-01
  4.46600258e-01 -4.62789178e-01 -1.34769842e-01  3.19547892e-01
 -5.65379597e-02  2.82305449e-01  2.12051079e-01  4.23187800e-02
 -1.97578799e-02 -2.86127329e-02 -1.50137886e-01 -1.93763256e-01
 -3.22995812e-01 -2.61391718e-02 -1.26339540e-01 -1.05652116e-01
 -4.59945560e-01  3.26163024e-01 -6.90251961e-02 -6.56406164e-01
 -7.31631368e-03 -5.77226281e-01 -5.31393662e-02  2.41409793e-01
  1.95834175e-01  2.08265726e-02  2.57992476e-01 -1.02691181e-01
 -4.55640197e-01 -1.07200548e-01  3.46844792e-01  1.93741232e-01
  6.17668591e-02  5.74026167e-01  4.01071548e-01 -2.61786968e-01
  7.16343671e-02  4.97496605e-01  1.13690913e-01 -4.30343896e-02
 -3.00275445e-01 -8.53034854e-03  2.78470397e-01  2.61030167e-01
  2.79544480e-02 -1.75255254e-01 -2.32070953e-01 -1.67083055e-01
  1.67819858e-01  4.12707150e-01 -4.08259630e-02  2.51850873e-01
 -2.62774169e-01 -5.12763262e-01 -3.82456854e-02  2.18899012e-01
  7.41210654e-02 -2.32950717e-01  1.94279507e-01 -7.13804841e-01
  2.50143409e-01 -1.48955956e-01 -2.23570973e-01 -4.02471393e-01
  3.34638238e-01 -4.02647406e-01 -3.42453271e-01 -1.76552385e-01
 -3.68827909e-01  3.81465614e-01 -5.35965383e-01 -2.01853126e-01
  1.29532382e-01 -4.25518423e-01  6.84260353e-02  1.67022616e-01
  1.48450226e-01  9.26124007e-02 -3.91401723e-03 -4.56557214e-01
 -2.07771540e-01 -5.74990287e-02  5.40614367e-01 -1.25176832e-01
  1.02945291e-01 -7.64323831e-01  1.47866517e-01 -3.19176078e-01
 -1.15690842e-01 -5.21897711e-02 -1.35031670e-01 -2.64891386e-02
 -1.37545079e-01  1.51047215e-01  1.12956248e-01 -5.44001535e-02
  6.90090775e-01  2.56273896e-01  9.92183536e-02  2.52858214e-02
 -2.08148986e-01  1.26472518e-01  1.75120831e-01  5.34753084e-01
  2.43254393e-01  6.98399022e-02 -3.02109569e-01  4.14620936e-02
  1.36226147e-01 -2.07447410e-01 -1.18070401e-01 -3.58627468e-01
  6.00513443e-02 -1.53634474e-01 -4.24920440e-01 -1.41288310e-01
  1.55115083e-01 -2.68099755e-01 -1.98799863e-01 -1.64619356e-01
 -3.23389560e-01 -1.58983558e-01  8.24633539e-02  2.49341011e-01
  4.16259289e-01 -1.12260982e-01  2.87567377e-01 -1.91848487e-01
  2.99432158e-01  8.42712164e-01  4.04469669e-01 -1.39399081e-01
 -3.13903958e-01 -3.01999509e-01  3.76299202e-01  6.45087540e-01
 -1.25392675e-01  9.64604095e-02  3.18872482e-01 -2.35080004e-01
  8.04091617e-02 -6.11730814e-01  1.15417950e-01  2.25882918e-01
 -2.76020262e-02  4.13161963e-01  2.91102350e-01 -3.97978544e-01
  2.60882825e-03  1.56371310e-01  2.31554940e-01  1.78175405e-01
 -5.46089411e-01  2.82149255e-01  3.00847720e-02  9.07731503e-02
 -4.63312790e-02  4.72794712e-01 -7.80103449e-03 -1.00215800e-01
  4.52076904e-02  6.87544495e-02 -1.79666489e-01  1.30811930e-02
  3.62107098e-01 -7.60457754e-01 -4.43303525e-01 -4.10377905e-02
  4.83638011e-02  3.26310582e-02  1.93693250e-01 -2.20095180e-02
  1.91959813e-01 -1.88390493e-01 -1.47601198e-02 -2.05618232e-01
  3.07976127e-01 -6.78170249e-02  4.28609371e-01  2.14761477e-02
 -4.11771834e-02 -3.80589396e-01 -2.61771560e-01  3.17894965e-01
 -2.55945742e-01 -8.90935734e-02  3.16597641e-01  2.82236576e-01
 -5.44986203e-02  1.44175813e-01 -5.15609011e-02  6.39448762e-02
 -6.27623498e-03  4.38402295e-01 -5.72556481e-02  5.86790919e-01
  6.08319521e-01 -4.85747218e-01 -3.87408853e-01  2.16855437e-01
  6.40264392e-01  4.38031852e-02 -1.03336580e-01  1.35861501e-01
  5.60979508e-02  2.81782985e-01 -1.90072536e-01  2.90505327e-02
 -5.93784571e-01  1.94779053e-01  2.28981730e-02  9.34879556e-02
  2.36952007e-02  3.53503972e-04 -1.57684103e-01 -6.75454497e-01
 -4.16141987e-01 -6.15828037e-01  4.26031053e-01  3.69950384e-02
 -1.31642386e-01 -1.66027844e-02 -1.39612168e-01 -3.36009301e-02
  3.70628953e-01  1.20601773e-01  3.36115479e-01 -1.44246310e-01
  1.15085855e-01  4.26943153e-01  2.50236213e-01 -1.35589272e-01
  4.47095633e-01 -1.77276120e-01 -3.72436792e-01 -4.51673359e-01
  2.95148958e-02  4.94144797e-01  2.96961218e-01 -2.76566565e-01
 -3.02508384e-01  1.47394210e-01 -1.01254903e-01  7.55400598e-01
  1.03015691e-01 -3.36666703e-01 -3.64295065e-01  3.00884098e-01
 -1.23272642e-01 -3.89993519e-01 -1.89144909e-01  3.38528395e-01
  5.98880708e-01 -2.46146291e-01 -3.32669824e-01  6.26381159e-01
 -4.16854843e-02 -4.28673327e-01 -3.84301931e-01 -3.37470025e-01
 -2.57837415e-01 -8.77049863e-02 -3.56373012e-01  1.62674598e-02
  4.41732183e-02 -1.69796079e-01  1.32300556e-01  2.74283558e-01
  3.14207792e-01 -6.11468777e-03 -4.83343810e-01 -3.35471153e-01
 -6.01924777e-01  2.23589420e-01  4.89388943e-01  9.13842991e-02
  9.91458073e-02 -1.13925450e-01  1.66930526e-01 -2.25717425e-01]

Elasticsearch embedding output:

 [
        0.15702076256275177,
        -0.30497047305107117,
        -0.4185786843299866,
        0.19540652632713318,
        0.20200636982917786,
        0.05581476911902428,
        -0.4383981227874756,
        -0.2811791002750397,
        0.41555774211883545,
        -0.041800498962402344,
        -0.29202359914779663,
        0.07583139091730118,
        -0.17099213600158691,
        0.049907222390174866,
        0.17646493017673492,
        -0.20833620429039001,
        -0.11859560012817383,
        -0.5652425289154053,
        0.7417986392974854,
        0.485240638256073,
        -0.1643780618906021,
        -0.30718109011650085,
        0.10171419382095337,
        0.034056372940540314,
        -0.49499064683914185,
        0.16654735803604126,
        -0.16653423011302948,
        0.16029980778694153,
        -0.1572583019733429,
        -0.307429701089859,
        0.06335984915494919,
        0.24298489093780518,
        -0.019913621246814728,
        0.37966465950012207,
        -0.1829199194908142,
        0.39311105012893677,
        0.06193055212497711,
        0.028584511950612068,
        -0.139212504029274,
        -0.00168626569211483,
        0.06565070152282715,
        0.04867163300514221,
        -0.14740845561027527,
        -0.1803678572177887,
        -0.1930307298898697,
        -0.1992124319076538,
        0.41560620069503784,
        -0.03520427644252777,
        -0.3295150399208069,
        -0.09906652569770813,
        -0.20895135402679443,
        -0.09691977500915527,
        0.23128646612167358,
        0.19929206371307373,
        0.12457077205181122,
        -0.7739284038543701,
        0.3260864317417145,
        -0.06660394370555878,
        -0.4744216203689575,
        0.23276813328266144,
        -0.4481157064437866,
        -0.3220774233341217,
        -0.2233225703239441,
        -0.20146265625953674,
        0.6268453598022461,
        -0.2978348433971405,
        -0.5715934038162231,
        0.30275771021842957,
        0.0052460841834545135,
        -0.13122424483299255,
        -0.3094833493232727,
        -0.24776297807693481,
        -0.08122719079256058,
        -0.12660977244377136,
        -0.12495452910661697,
        0.19952230155467987,
        0.053276002407073975,
        0.09066520631313324,
        0.21473687887191772,
        -0.2313809096813202,
        -0.36555853486061096,
        0.12973010540008545,
        0.45897582173347473,
        0.21715323626995087,
        -0.00890918355435133,
        0.5443965792655945,
        0.2205493152141571,
        0.08791373670101166,
        -0.06921732425689697,
        -0.12688428163528442,
        -0.0904388576745987,
        0.12339566648006439,
        0.27256447076797485,
        0.2657807767391205,
        0.2993202805519104,
        0.2142896205186844,
        0.07393170893192291,
        0.163153737783432,
        -0.13501033186912537,
        1.8458795547485352,
        0.7737840414047241,
        0.17607846856117249,
        -0.2754819095134735,
        -0.1298648864030838,
        0.4466000199317932,
        -0.462789386510849,
        -0.1347695291042328,
        0.31954824924468994,
        -0.056537315249443054,
        0.2823052704334259,
        0.21205157041549683,
        0.04231896251440048,
        -0.019757771864533424,
        -0.02861320599913597,
        -0.15013769268989563,
        -0.19376340508460999,
        -0.3229956328868866,
        -0.02613930031657219,
        -0.12633869051933289,
        -0.10565194487571716,
        -0.45994502305984497,
        0.32616275548934937,
        -0.06902498751878738,
        -0.6564062833786011,
        -0.007315844297409058,
        -0.5772258043289185,
        -0.05313921719789505,
        0.2414097934961319,
        0.19583415985107422,
        0.020826343446969986,
        0.25799208879470825,
        -0.10269100964069366,
        -0.4556404650211334,
        -0.10720020532608032,
        0.34684425592422485,
        0.19374054670333862,
        0.06176705285906792,
        0.5740267634391785,
        0.40107157826423645,
        -0.2617875933647156,
        0.07163386046886444,
        0.4974963665008545,
        0.11369092762470245,
        -0.04303419589996338,
        -0.3002757430076599,
        -0.008530613034963608,
        0.278470516204834,
        0.2610301375389099,
        0.02795415371656418,
        -0.1752547174692154,
        -0.23207062482833862,
        -0.16708329319953918,
        0.16782008111476898,
        0.4127070903778076,
        -0.04082632064819336,
        0.2518506944179535,
        -0.26277419924736023,
        -0.5127633810043335,
        -0.03824537992477417,
        0.21889907121658325,
        0.07412082701921463,
        -0.2329506129026413,
        0.19427986443042755,
        -0.7138044834136963,
        0.25014370679855347,
        -0.14895588159561157,
        -0.22357085347175598,
        -0.4024714231491089,
        0.33463814854621887,
        -0.40264731645584106,
        -0.3424539566040039,
        -0.17655247449874878,
        -0.36882784962654114,
        0.38146549463272095,
        -0.535965621471405,
        -0.20185306668281555,
        0.12953223288059235,
        -0.4255184829235077,
        0.06842581927776337,
        0.16702216863632202,
        0.1484498232603073,
        0.09261268377304077,
        -0.003913957625627518,
        -0.45655784010887146,
        -0.20777229964733124,
        -0.05749869346618652,
        0.540614902973175,
        -0.12517639994621277,
        0.10294540971517563,
        -0.7643243074417114,
        0.14786696434020996,
        -0.31917622685432434,
        -0.1156909167766571,
        -0.052189432084560394,
        -0.1350315511226654,
        -0.02648896351456642,
        -0.13754528760910034,
        0.15104718506336212,
        0.11295602470636368,
        -0.05439971014857292,
        0.6900904178619385,
        0.2562732696533203,
        0.0992187112569809,
        0.025286098942160606,
        -0.20814867317676544,
        0.1264726221561432,
        0.17512096464633942,
        0.5347529053688049,
        0.2432541847229004,
        0.06983993947505951,
        -0.30210912227630615,
        0.041461870074272156,
        0.13622574508190155,
        -0.2074476182460785,
        -0.11807098984718323,
        -0.3586277365684509,
        0.06005071848630905,
        -0.15363474190235138,
        -0.42492052912712097,
        -0.1412879079580307,
        0.15511471033096313,
        -0.26809969544410706,
        -0.1987997591495514,
        -0.16461965441703796,
        -0.32338929176330566,
        -0.15898416936397552,
        0.08246305584907532,
        0.24934083223342896,
        0.41625896096229553,
        -0.11226070672273636,
        0.2875669598579407,
        -0.19184821844100952,
        0.299431711435318,
        0.8427122831344604,
        0.404469758272171,
        -0.1393987536430359,
        -0.31390389800071716,
        -0.3019992709159851,
        0.37629878520965576,
        0.6450868844985962,
        -0.125392884016037,
        0.09646060317754745,
        0.31887197494506836,
        -0.235080286860466,
        0.08040882647037506,
        -0.611730694770813,
        0.11541779339313507,
        0.22588294744491577,
        -0.027601953595876694,
        0.4131621718406677,
        0.2911023199558258,
        -0.39797815680503845,
        0.002609340474009514,
        0.1563713252544403,
        0.23155444860458374,
        0.1781758964061737,
        -0.5460891127586365,
        0.28214946389198303,
        0.03008478507399559,
        0.09077396988868713,
        -0.04633083939552307,
        0.47279486060142517,
        -0.007801243104040623,
        -0.10021523386240005,
        0.04520805552601814,
        0.06875470280647278,
        -0.17966699600219727,
        0.013080738484859467,
        0.3621068596839905,
        -0.7604570388793945,
        -0.44330349564552307,
        -0.041038159281015396,
        0.048363782465457916,
        0.03263121098279953,
        0.19369381666183472,
        -0.022009439766407013,
        0.19196072220802307,
        -0.1883908361196518,
        -0.014760453253984451,
        -0.2056182622909546,
        0.30797678232192993,
        -0.06781665980815887,
        0.42861008644104004,
        0.02147628739476204,
        -0.04117707535624504,
        -0.38059002161026,
        -0.2617712616920471,
        0.3178943395614624,
        -0.2559453845024109,
        -0.08909349888563156,
        0.3165978789329529,
        0.2822365164756775,
        -0.0544988177716732,
        0.14417582750320435,
        -0.05156082287430763,
        0.06394510716199875,
        -0.0062770722433924675,
        0.4384019374847412,
        -0.05725526064634323,
        0.5867913365364075,
        0.6083194613456726,
        -0.4857479929924011,
        -0.38740915060043335,
        0.21685516834259033,
        0.6402644515037537,
        0.04380349814891815,
        -0.10333643853664398,
        0.13586175441741943,
        0.056098029017448425,
        0.28178250789642334,
        -0.19007229804992676,
        0.02905029058456421,
        -0.5937849283218384,
        0.1947793960571289,
        0.022898219525814056,
        0.09348764270544052,
        0.023695774376392365,
        0.00035322830080986023,
        -0.15768417716026306,
        -0.6754546165466309,
        -0.41614240407943726,
        -0.6158281564712524,
        0.4260311722755432,
        0.03699488192796707,
        -0.131642684340477,
        -0.01660257577896118,
        -0.13961206376552582,
        -0.03360098972916603,
        0.37062859535217285,
        0.12060171365737915,
        0.3361155688762665,
        -0.14424622058868408,
        0.11508522927761078,
        0.4269430339336395,
        0.2502363324165344,
        -0.13558876514434814,
        0.44709524512290955,
        -0.17727646231651306,
        -0.3724367916584015,
        -0.4516729712486267,
        0.029515277594327927,
        0.4941454231739044,
        0.2969611883163452,
        -0.2765662968158722,
        -0.30250829458236694,
        0.1473938226699829,
        -0.10125508159399033,
        0.7554014921188354,
        0.10301551222801208,
        -0.3366667628288269,
        -0.3642953336238861,
        0.3008842170238495,
        -0.12327177822589874,
        -0.38999369740486145,
        -0.18914549052715302,
        0.33852851390838623,
        0.5988801717758179,
        -0.24614626169204712,
        -0.3326700031757355,
        0.6263812780380249,
        -0.04168548062443733,
        -0.42867356538772583,
        -0.3843013346195221,
        -0.337470144033432,
        -0.2578378915786743,
        -0.08770501613616943,
        -0.3563733696937561,
        0.01626719906926155,
        0.04417328163981438,
        -0.1697964072227478,
        0.13230091333389282,
        0.27428460121154785,
        0.314208060503006,
        -0.0061147138476371765,
        -0.4833439886569977,
        -0.3354712426662445,
        -0.6019247174263,
        0.2235894799232483,
        0.489388644695282,
        0.09138386696577072,
        0.09914615005254745,
        -0.11392566561698914,
        0.16693009436130524,
        -0.22571784257888794
      ]

How can we add the mean pooling layer in Elasticsearch?

You have two options, both of which can be done with a few lines of Python code in Eland.

The first option is to configure the model in Elasticsearch with the special pass_through task type. pass_through returns the entire model output directly, and you can compute the mean pooling from that output yourself.

To do this you need this small change in Eland, then upload the model with this command:

eland_import_hub_model \
      --url XXX \
      -u XXX -p XXXX \
      --hub-model-id "sentence-transformers/msmarco-MiniLM-L-12-v3" \
      --task-type pass_through 
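As a sketch of what the client-side pooling could then look like (treat the response field names and nesting here as assumptions; the exact shape of the pass_through response may differ):

```python
# Hypothetical: mean-pool a pass_through result that returns one vector
# per token. The nesting below is an assumption for illustration only.
def mean_pool(per_token_vectors):
    """Average a list of equal-length token vectors into one embedding."""
    n = len(per_token_vectors)
    dim = len(per_token_vectors[0])
    return [sum(vec[d] for vec in per_token_vectors) / n for d in range(dim)]

# e.g. response["inference_results"][0]["predicted_value"] might hold a
# list of per-token vectors when the task type is pass_through:
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(mean_pool(tokens))  # [3.0, 4.0]
```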

The second option is to wrap the Sentence Transformers model and add the pooling function. I think this is the better option.

Once again we must add some code to Eland and build it locally. Go to transformers.py and add the following at line 269:

class _SentenceTransformerPooler(nn.Module):  # type: ignore
    def __init__(self, model: transformers.PreTrainedModel):
        super().__init__()
        self._model = model
        self.config = model.config    

    def mean_pooling(self, model_output: Tensor, attention_mask: Tensor) -> Tensor:
        # Average the token embeddings, excluding padded positions.
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        token_type_ids: Tensor,
        position_ids: Tensor,
    ) -> Tensor:
        """Wrap the input and output to conform to the native process interface."""

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
        }

        model_output = self._model(**inputs)
        return self.mean_pooling(model_output, attention_mask)    

I've taken the mean_pooling function from the sentence-transformers/msmarco-MiniLM-L-12-v3 model card on Hugging Face.
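The attention mask matters because padded positions should not contribute to the average. Here is a plain-Python sketch of the same masked mean with toy numbers (no torch, purely for illustration):

```python
# Masked mean pooling: padded tokens (mask == 0) are excluded from the average.
def masked_mean(token_embeddings, attention_mask):
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, mask in zip(token_embeddings, attention_mask):
        if mask:  # only real tokens contribute
            count += 1
            for d in range(dim):
                total[d] += vec[d]
    count = max(count, 1)  # guards against division by zero, like torch.clamp(..., min=1e-9)
    return [t / count for t in total]

# Two real tokens plus one padding token that must be ignored:
emb = [[2.0, 4.0], [4.0, 8.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(masked_mean(emb, mask))  # [3.0, 6.0]
```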

Then, in the function _create_traceable_model, we will hack the code to use the new class; I'm sure you can come up with a more elegant solution than this hack.

Find the elif branch:

        elif self._task_type == "text_embedding":
            model = _SentenceTransformerWrapperModule.from_pretrained(self._model_id)
            if not model:
                model = _DPREncoderWrapper.from_pretrained(self._model_id)
            if not model:
                model = transformers.AutoModel.from_pretrained(
                    self._model_id, torchscript=True
                )
            return _TraceableTextEmbeddingModel(self._tokenizer, model)

and replace it with:

        elif self._task_type == "text_embedding":
            model = transformers.AutoModel.from_pretrained(self._model_id, torchscript=True)
            return _TraceableTextEmbeddingModel(self._tokenizer, _SentenceTransformerPooler(model))            

Upload the model as usual with eland_import_hub_model, and the result in Elasticsearch will now be the mean-pooled embedding.
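Once re-imported, you can sanity-check that the two pipelines agree, for example by comparing components and the cosine similarity. A plain-Python sketch, using the first few values of the two outputs posted above (truncated for illustration):

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# First components of the sentence_transformers and Elasticsearch outputs:
py_vec = [0.15702045, -0.30496997, -0.41857824, 0.19540675]
es_vec = [0.15702076, -0.30497047, -0.41857868, 0.19540653]

# Differences are only float-formatting noise; cosine similarity is ~1.0.
assert all(abs(p - e) < 1e-5 for p, e in zip(py_vec, es_vec))
print(cosine(py_vec, es_vec))
```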


The outputs are the same; Python has merely formatted the float values differently. Sentence Transformers models already retain their pooling layer when imported into Elasticsearch.