@@ -50,6 +50,7 @@ namespace tensorflow {
50
50
namespace {
51
51
const int64 kEmbeddingVarUseDB = -214 ;
52
52
const int64 kInitializableEmbeddingVarUseDB = -215 ;
53
+ const char * kInferenceMode = " INFERENCE_MODE" ;
53
54
}
54
55
55
56
#define REGISTER_KV_VAR_HANDLE (ktype, vtype ) \
@@ -438,6 +439,10 @@ template <typename TKey, typename TValue>
438
439
class KvResourceGatherOp : public OpKernel {
439
440
public:
440
441
explicit KvResourceGatherOp (OpKernelConstruction* c) : OpKernel(c) {
442
+ OP_REQUIRES_OK (c, c->GetAttr (" is_inference" , &is_inference_));
443
+ bool is_inference;
444
+ TF_CHECK_OK (ReadBoolFromEnvVar (kInferenceMode , false , &is_inference));
445
+ is_inference_ |= is_inference;
441
446
OP_REQUIRES_OK (c,
442
447
c->GetAttr (" is_use_default_value_tensor" ,
443
448
&is_use_default_value_tensor_));
@@ -461,6 +466,17 @@ class KvResourceGatherOp : public OpKernel {
461
466
return 1 ;
462
467
};
463
468
}
469
+ if (!is_inference_) {
470
+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
471
+ TValue* val, TValue* default_v, int count) {
472
+ ev->LookupOrCreate (key, val, default_v, count);
473
+ };
474
+ } else {
475
+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
476
+ TValue* val, TValue* default_v, int count) {
477
+ ev->Lookup (key, val, default_v);
478
+ };
479
+ }
464
480
}
465
481
466
482
void Compute (OpKernelContext* c) override {
@@ -511,7 +527,7 @@ class KvResourceGatherOp : public OpKernel {
511
527
default_v, indices_flat (i), i, ev->GetDefaultValueDim (),
512
528
ev->ValueLen ());
513
529
int32 count = get_count_fn_ (counts, i);
514
- ev-> LookupOrCreate ( indices_flat (i),
530
+ lookup_fn_ (ev, indices_flat (i),
515
531
out_base + i * slice_elems, default_v_ptr, count);
516
532
}
517
533
};
@@ -530,9 +546,12 @@ class KvResourceGatherOp : public OpKernel {
530
546
531
547
private:
532
548
bool is_use_default_value_tensor_;
549
+ bool is_inference_;
533
550
std::function<
534
551
TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
535
552
std::function<int32(int32*, int64)> get_count_fn_;
553
+ std::function<void (EmbeddingVar<TKey, TValue>* ev,
554
+ TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
536
555
};
537
556
538
557
#define REGISTER_GATHER_FULL (dev, ktype, vtype ) \
0 commit comments