@@ -52,6 +52,7 @@ using GPUDevice = Eigen::GpuDevice;
52
52
// File-local constants (anonymous namespace, so internal linkage).
namespace {
53
53
// Sentinel value; its meaning is not visible in this excerpt.
// TODO(review): document where this value is compared or consumed.
const int64 kEmbeddingVarUseDB = -214 ;
54
54
// Sentinel value; its meaning is not visible in this excerpt.
// TODO(review): document where this value is compared or consumed.
const int64 kInitializableEmbeddingVarUseDB = -215 ;
55
// Name of the environment variable that the KvResourceGatherOp constructor
// reads with ReadBoolFromEnvVar (default false) to switch on inference mode.
// NOTE(review): as rendered here the literal starts with a space character;
// this is presumably an artifact of the diff view. Confirm against the real
// file, because a leading space would change the variable name looked up.
+ const char * kInferenceMode = " INFERENCE_MODE" ;
55
56
}
56
57
57
58
#define REGISTER_KV_VAR_HANDLE (ktype, vtype ) \
@@ -370,6 +371,10 @@ template <typename TKey, typename TValue>
370
371
// Gather kernel over an EmbeddingVar<TKey, TValue>.
//
// This view is a diff: lines starting with + are added, lines starting with -
// are removed, and the @@ headers mark places where unchanged code is elided.
// The change shown here adds an inference mode. When the mode is off the
// kernel calls EmbeddingVar::LookupOrCreate per index (the old behaviour);
// when it is on, it calls EmbeddingVar::Lookup instead.
class KvResourceGatherOp : public OpKernel {
371
372
public:
372
373
// Reads the kernel attributes and picks the per-key lookup callable once, at
// construction time, so Compute only has to invoke lookup_fn_.
explicit KvResourceGatherOp (OpKernelConstruction* c) : OpKernel(c) {
374
// Inference mode can be requested through the op attribute read here ...
+ OP_REQUIRES_OK (c, c->GetAttr (" is_inference" , &is_inference_));
375
+ bool is_inference;
376
// ... or through the environment variable named by kInferenceMode, which
// defaults to false. TF_CHECK_OK is applied to the returned status.
+ TF_CHECK_OK (ReadBoolFromEnvVar (kInferenceMode , false , &is_inference));
377
// Either source being true is enough to enable inference mode.
+ is_inference_ |= is_inference;
373
378
OP_REQUIRES_OK (c,
374
379
c->GetAttr (" is_use_default_value_tensor" ,
375
380
&is_use_default_value_tensor_));
// (hunk gap: constructor code between here and the next lines is not shown)
@@ -393,6 +398,17 @@ class KvResourceGatherOp : public OpKernel {
393
398
// Tail of a statement whose beginning lies in the elided region above.
return 1 ;
394
399
};
395
400
}
401
// Select the lookup strategy. Both lambdas share one signature so that the
// call site in Compute is identical in either mode.
+ if (!is_inference_) {
402
+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
403
+ TValue* val, TValue* default_v, int count) {
404
// Non-inference path: forwards all arguments, including count.
+ ev->LookupOrCreate (key, val, default_v, count);
405
+ };
406
+ } else {
407
+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
408
+ TValue* val, TValue* default_v, int count) {
409
// Inference path: count is accepted but not used.
// NOTE(review): presumably Lookup does not insert missing keys and
// falls back to default_v; confirm against EmbeddingVar.
+ ev->Lookup (key, val, default_v);
410
+ };
411
+ }
396
412
}
397
413
398
414
// Kernel entry point. Most of the body is elided from this view.
void Compute (OpKernelContext* c) override {
// (hunk gap: the beginning of Compute is not shown)
@@ -443,7 +459,7 @@ class KvResourceGatherOp : public OpKernel {
443
459
// Trailing arguments of a call that starts in the elided region.
default_v, indices_flat (i), i, ev->GetDefaultValueDim (),
444
460
ev->ValueLen ());
445
461
// Count for index i, obtained from the callable chosen in the constructor.
int32 count = get_count_fn_ (counts, i);
446
// Removed by this change: the unconditional LookupOrCreate call.
- ev-> LookupOrCreate ( indices_flat (i),
462
// Added by this change: dispatch through lookup_fn_, writing the result
// for index i at out_base + i * slice_elems.
+ lookup_fn_ (ev, indices_flat (i),
447
463
out_base + i * slice_elems, default_v_ptr, count);
448
464
}
449
465
};
// (hunk gap: the remainder of Compute is not shown)
@@ -463,9 +479,12 @@ class KvResourceGatherOp : public OpKernel {
463
479
464
480
private:
465
481
// Value of the is_use_default_value_tensor attribute.
bool is_use_default_value_tensor_;
482
// True when either the is_inference attribute or the environment variable
// named by kInferenceMode is set.
+ bool is_inference_;
466
483
// Callable yielding the default-value pointer for one index; it is assigned
// in a part of the constructor that is not shown here.
std::function<
467
484
TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
468
485
// Callable yielding the count for one index.
std::function<int32(int32*, int64)> get_count_fn_;
486
// Per-key lookup strategy, fixed at construction from is_inference_.
+ std::function<void (EmbeddingVar<TKey, TValue>* ev,
487
+ TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
469
488
};
470
489
471
490
#define REGISTER_GATHER_FULL (dev, ktype, vtype ) \
0 commit comments